{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "#加载mnist_inference 和 mnist_train 中定义的常量和函数\n",
    "import mnist_inference\n",
    "import mnist_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seconds to wait between two consecutive checkpoint evaluations.\n",
    "EVAL_INTERVAL_SECS = 10\n",
    "\n",
    "\n",
    "def evaluate(mnist):\n",
    "    \"\"\"Periodically restore the newest checkpoint and print its validation accuracy.\n",
    "\n",
    "    A fresh graph is built with the forward pass shared from ``mnist_inference``.\n",
    "    The saver is created from ``ExponentialMovingAverage.variables_to_restore()``,\n",
    "    so the moving-average values stored in the checkpoint are loaded into the\n",
    "    model variables. The accuracy on ``mnist.validation`` is printed every\n",
    "    ``EVAL_INTERVAL_SECS`` seconds. The function returns as soon as no checkpoint\n",
    "    is found in ``mnist_train.MODEL_SAVE_PATH``; otherwise it loops forever.\n",
    "\n",
    "    Args:\n",
    "        mnist: Dataset object exposing ``validation.images`` and\n",
    "            ``validation.labels`` (in this notebook, the result of\n",
    "            ``input_data.read_data_sets(..., one_hot=True)``).\n",
    "    \"\"\"\n",
    "    with tf.Graph().as_default():\n",
    "        # Placeholders describing the format of the inputs and of the labels.\n",
    "        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')\n",
    "        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')\n",
    "        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}\n",
    "\n",
    "        # Reuse the packaged forward pass. The regularizer argument is None because\n",
    "        # the regularization loss is not needed when evaluating.\n",
    "        y = mnist_inference.inference(x, None)\n",
    "\n",
    "        # tf.argmax(y, 1) is the predicted class of each example; compare it with the\n",
    "        # label class to get the accuracy. (tf.arg_max is a deprecated alias.)\n",
    "        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "        # Load the model through variable renaming so the forward pass does not have\n",
    "        # to call the moving-average function itself and the code in\n",
    "        # mnist_inference can be shared unchanged.\n",
    "        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)\n",
    "        saver = tf.train.Saver(variable_averages.variables_to_restore())\n",
    "\n",
    "        # Re-evaluate every EVAL_INTERVAL_SECS seconds to follow the accuracy during training.\n",
    "        while True:\n",
    "            # tf.Session must be instantiated: `with tf.Session as sess` used the class\n",
    "            # itself as a context manager and raised AttributeError: __enter__.\n",
    "            with tf.Session() as sess:\n",
    "                # get_checkpoint_state reads the `checkpoint` file to find the newest model.\n",
    "                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)\n",
    "                if not (ckpt and ckpt.model_checkpoint_path):\n",
    "                    print('No checkpoint file found')\n",
    "                    return\n",
    "                saver.restore(sess, ckpt.model_checkpoint_path)\n",
    "                # The training step is the suffix after the last '-' of the file name.\n",
    "                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]\n",
    "                accuracy_score = sess.run(accuracy, feed_dict=validate_feed)\n",
    "                print(\"After %s trainning step(s), validation accuracy = %g\" % (global_step, accuracy_score))\n",
    "            # Sleep only after the session has been closed so none is held open while idle.\n",
    "            time.sleep(EVAL_INTERVAL_SECS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./tmp/data\\train-images-idx3-ubyte.gz\n",
      "Extracting ./tmp/data\\train-labels-idx1-ubyte.gz\n",
      "Extracting ./tmp/data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ./tmp/data\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "__enter__",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-7-81853d815133>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m     \u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m     \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\tool\\Miniconda3\\lib\\site-packages\\tensorflow\\python\\platform\\app.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(main, argv)\u001b[0m\n\u001b[0;32m     46\u001b[0m   \u001b[1;31m# Call the main function, passing through any arguments\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m   \u001b[1;31m# to the final program.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m   \u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margv\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mflags_passthrough\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-7-81853d815133>\u001b[0m in \u001b[0;36mmain\u001b[1;34m(argv)\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margv\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m     \u001b[0mmnist\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_data\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"./tmp/data\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mone_hot\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m     \u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-5-c458b3fbd3ca>\u001b[0m in \u001b[0;36mevaluate\u001b[1;34m(mnist)\u001b[0m\n\u001b[0;32m     21\u001b[0m         \u001b[1;31m#每隔EVAL_INTERVAL_SECS 秒调用一次计算正确率的过程以检测训练过程中正确率的变化\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m         \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m             \u001b[1;32mwith\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSession\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0msess\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     24\u001b[0m                 \u001b[1;31m#tf.train.get_checkpoint_state 函数会通过checkpoint 文件自动找到目录中最新模型的文件名\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m                 \u001b[0mckpt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_checkpoint_state\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist_train\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMODEL_SAVE_PATH\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: __enter__"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    \"\"\"Read the MNIST data from ./tmp/data and start the evaluation loop.\n",
    "\n",
    "    Args:\n",
    "        argv: Arguments handed over by tf.app.run(); not used here.\n",
    "    \"\"\"\n",
    "    evaluate(input_data.read_data_sets(\"./tmp/data\", one_hot=True))\n",
    "\n",
    "\n",
    "# tf.app.run() calls main() and passes its return value to sys.exit().\n",
    "if __name__ == '__main__':\n",
    "    tf.app.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
